{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import tensorflow as tf\n",
    "import tensorflow.contrib.eager as tfe\n",
    "tf.enable_eager_execution()\n",
    "import numpy as np"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# 数据（国内）"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import subprocess\n",
    "from tensorflow.examples.tutorials.mnist import input_data\n",
    "#下载 mnist 数据集\n",
    "commands = ['wget https://devlab-1251520893.cos.ap-guangzhou.myqcloud.com/t10k-images-idx3-ubyte.gz',\n",
    "            'wget https://devlab-1251520893.cos.ap-guangzhou.myqcloud.com/t10k-labels-idx1-ubyte.gz',\n",
    "            'wget https://devlab-1251520893.cos.ap-guangzhou.myqcloud.com/train-images-idx3-ubyte.gz',\n",
    "            'wget https://devlab-1251520893.cos.ap-guangzhou.myqcloud.com/train-labels-idx1-ubyte.gz']\n",
    "for c in commands:\n",
    "    subprocess.call(c, shell=True)\n",
    "\n",
    "mnist = input_data.read_data_sets('./', one_hot=False)\n",
    "\n",
    "train_x = mnist.train.images.astype('float32')\n",
    "train_y = mnist.train.labels.astype('int32')\n",
    "test_x = mnist.test.images.astype('float32')\n",
    "test_y = mnist.test.labels.astype('int32')\n",
    "\n",
    "#print('数字形式的图片：%s' %train_x[0].reshape([28,28]))\n",
    "print('图片数字：%s' %train_y[0])\n",
    "print('训练数据量 %s, %s' %(train_x.shape,train_y.shape))\n",
    "print('测试数据量 %s, %s' %(test_x.shape,test_y.shape))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# 数据（国外用户）"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Populating the interactive namespace from numpy and matplotlib\n",
      "图片数字：5\n",
      "训练数据量 (60000, 784), (60000,)\n",
      "测试数据量 (10000, 784), (10000,)\n"
     ]
    },
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAP8AAAD8CAYAAAC4nHJkAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDIuMi4zLCBodHRwOi8vbWF0cGxvdGxpYi5vcmcvIxREBQAADgpJREFUeJzt3X+MVfWZx/HPs1j+kKI4aQRCYSnEYJW4082IjSWrxkzVDQZHrekkJjQapn8wiU02ZA3/VNNgyCrslmiamaZYSFpKE3VB0iw0otLGZuKIWC0srTFsO3IDNTjywx9kmGf/mEMzxbnfe+fec++5zPN+JeT+eM6558kNnznn3O+592vuLgDx/EPRDQAoBuEHgiL8QFCEHwiK8ANBEX4gKMIPBEX4gaAIPxDUZc3cmJlxOSHQYO5u1SxX157fzO40syNm9q6ZPVrPawFoLqv12n4zmybpj5I6JQ1Jel1St7sfSqzDnh9osGbs+ZdJetfd33P3c5J+IWllHa8HoInqCf88SX8Z93goe+7vmFmPmQ2a2WAd2wKQs3o+8Jvo0OJzh/Xu3i+pX+KwH2gl9ez5hyTNH/f4y5KO1dcOgGapJ/yvS7rGzL5iZtMlfVvSrnzaAtBoNR/2u/uImfVK2iNpmqQt7v6H3DoD0FA1D/XVtDHO+YGGa8pFPgAuXYQfCIrwA0ERfiAowg8ERfiBoAg/EBThB4Ii/EBQhB8IivADQRF+ICjCDwRF+IGgCD8QFOEHgiL8QFCEHwiK8ANBEX4gKMIPBEX4gaAIPxAU4QeCIvxAUIQfCIrwA0ERfiAowg8EVfMU3ZJkZkclnZZ0XtKIu3fk0RTyM23atGT9yiuvbOj2e3t7y9Yuv/zy5LpLlixJ1tesWZOsP/XUU2Vr3d3dyXU//fTTZH3Dhg3J+uOPP56st4K6wp+5zd0/yOF1ADQRh/1AUPWG3yXtNbM3zKwnj4YANEe9h/3fcPdjZna1pF+b2f+6+/7xC2R/FPjDALSYuvb87n4suz0h6QVJyyZYpt/dO/gwEGgtNYffzGaY2cwL9yV9U9I7eTUGoLHqOeyfLekFM7vwOj939//JpSsADVdz+N39PUn/lGMvU9aCBQuS9enTpyfrN998c7K+fPnysrVZs2Yl173vvvuS9SINDQ0l65s3b07Wu7q6ytZOnz6dXPett95K1l999dVk/VLAUB8QFOEHgiL8QFCEHwiK8ANBEX4gKHP35m3MrHkba6L29vZkfd++fcl6o79W26pGR0eT9YceeihZP3PmTM3bLpVKyfqHH36YrB85cqTmbTeau1s1y7HnB4Ii/EBQhB8IivADQRF+ICjCDwRF+IGgGOfPQVtbW7I+MDCQrC9atCjPdnJVqffh4eFk/bbbbitbO3fuXHLdqNc/1ItxfgBJhB8IivADQRF+ICjCDwRF+IGgCD8QVB6z9IZ38uTJZH3t2rXJ+ooVK5L1N998M1mv9BPWKQcPHkzWOzs7k/WzZ88m69dff33Z2iOPPJJcF43Fnh8IivADQRF+ICjCDwRF+IGgCD8QFOEHgqr4fX4z2yJphaQT7r40e65N0g5JCyUdlfSAu6d/6FxT9/v89briiiuS9UrTSff19ZWtPfzww8l1H3zwwWR9+/btyTpaT57f5/+ppDsveu5RSS+5+zWSXsoeA7iEVAy/u++XdPElbCslbc3ub5V0T859AWiwWs/5Z7t7SZKy26vzawlAMzT82n4z65HU0+jtAJicWvf8x81sriRltyfKLeju/e7e4e4dNW4LQAPUGv5dklZl91dJ2plPOwCapWL4zWy7pN9JWmJmQ2b2sKQNkjrN7E+SOrPHAC4hFc/53b27TOn2nHsJ69SpU3Wt/9FHH9W87urVq5P1HTt2JOujo6M1bxvF4go/ICjCDwRF+IGgCD8QFOEHgiL8QFBM0T0FzJgxo2ztxRdfTK57yy23JOt33XVXsr53795kHc3HFN0Akgg/EBThB4Ii/EBQhB8IivADQRF+ICjG+ae4xYsXJ+sHDhxI1oeHh5P1l19+OVkfHBwsW3vmmWeS6zbz/+ZUwjg/gCTCDwRF+IGgCD8QFOEHgiL8QFCEHwiKcf7gurq6kvVnn302WZ85c2bN2163bl2yvm3btmS9VCrVvO2pjHF+AEmEHwiK8ANBEX4gKMIPBEX4gaAIPxBUxXF+M9siaYWkE+6+NHvuMUmrJf01W2ydu/+q4sYY57/kLF26NFnftGlTsn777bXP5N7X15esr1+/Pll///33a972pSzPcf6fSrpzguf/093bs38Vgw+gtVQMv7vvl3SyCb0AaKJ6zvl7zez3ZrbFzK7KrSMATVFr+H8kabGkdkklSRvLLWhmPWY2aGblf8wNQNPVFH53P+7u5919VNKPJS1LLNvv7h3u3lFrkwDyV1P4zWzuuIddkt7Jpx0AzXJZpQXMbLukWyV9ycyGJH1f0q1m1i7JJR2V9N0G9gigAfg+P+oya9asZP3uu+8uW6v0WwFm6eHqffv2JeudnZ3J+lTF9/kBJBF+ICjCDwRF+IGgCD8QFOEHgmKoD4X57LPPkvXLLktfhjIyMpKs33HHHWVrr7zySnLdSxlDfQCSCD8QFOEHgiL8QFCEHwiK8ANBEX4gqIrf50dsN9xwQ7J+//33J+s33nhj2VqlcfxKDh06lKzv37+/rtef6tjzA0ERfiAowg8ERfiBoAg/EBThB4Ii/EBQjPNPcUuWLEnWe3t7k/V77703WZ8zZ86ke6rW+fPnk/VSqZSsj46O5tnOlMOeHwiK8ANBEX4gKMIPBEX4gaAIPxAU4QeCqjjOb2bzJW2TNEfSqKR+d/+hmbVJ2iFpoaSjkh5w9w8b12pclcbSu7u7y9YqjeMvXLiwlpZyMTg4mKyvX78+Wd+1a1ee7YRTzZ5/RNK/uftXJX1d0hozu07So5JecvdrJL2UPQZwiagYfncvufuB7P5pSYclzZO0UtLWbLGtku5pVJMA8jepc34zWyjpa5IGJM1295I09gdC0tV5Nwegcaq+tt/MvijpOUnfc/dTZlVNByYz65HUU1t7ABqlqj2/mX1BY8H/mbs/nz193MzmZvW5kk5MtK6797t7h7t35NEwgHxUDL+N7eJ/Iumwu28aV9olaVV2f5Wknfm3B6BRKk7RbWbLJf1G0tsaG+qTpHUaO+//paQFkv4s6VvufrLCa4Wconv27NnJ+nXXXZesP/3008n6tddeO+me8jIwMJCsP/nkk2VrO3em9xd8Jbc21U7RXfGc391/K6nci90+maYAtA6u8AOCIvxAUIQfCIrwA0ERfiAowg8ExU93V6mtra1sra+vL7lue3t7sr5o0aKaesrDa6+9lqxv3LgxWd+zZ0+y/sknn0y6JzQHe34gKMIPBEX4gaAIPxAU4QeCIvxAUIQfCCrMOP9NN92UrK9duzZZX7ZsWdnavHnzauopLx9//HHZ2ubNm5PrPvHEE8n62bNna+oJrY89PxAU4QeCIvxAUIQfCIrwA0ERfiAowg8EFWacv6urq656PQ4dOpSs7969O1kfGRlJ1lPfuR8eHk6ui7jY8wNBEX4gKMIPBEX4gaAIPxAU4QeCIvxAUObu6QXM5kvaJmmOpFFJ/e7+QzN7TNJqSX/NFl3n7r+q8FrpjQGom7tbNctVE/65kua6+wEzmynpDUn3SHpA0hl3f6rapgg/0HjVhr/iFX7uXpJUyu6fNrPDkor96RoAdZvUOb+ZLZT0NUkD2VO9ZvZ7M9tiZleVWafHzAbNbLCuTgHkquJh/98WNPuipFclrXf3581stqQPJLmkH2js1OChCq/BYT/QYLmd80uSmX1B0m5Je9x90wT1hZJ2u/vSCq9D+IEGqzb8FQ/7zcwk/UTS4fHBzz4IvKBL0juTbRJAcar5tH+5pN9IeltjQ32StE5St6R2jR32H5X03ezDwdRrsecHGizXw/68EH6g8XI77AcwNRF+ICjCDwRF+IGgCD8QFOEHgiL8QFCEHwiK8ANBEX4gKMIPBEX4gaAIPxAU4QeCavYU3R9I+r9xj7+UPdeKWrW3Vu1Lorda5dnbP1a7YFO/z/+5jZsNuntHYQ0ktGpvrdqXRG+1Kqo3DvuBoAg/EFTR4e8vePsprdpbq/Yl0VutCumt0HN+AMUpes8PoCCFhN/M7jSzI2b2rpk9WkQP5ZjZUTN728wOFj3FWDYN2gkze2fcc21m9msz+1N2O+E0aQX19piZvZ+9dwfN7F8L6m2+mb1sZofN7A9m9kj2fKHvXaKvQt63ph/2m9k0SX+U1ClpSNLrkrrd/VBTGynDzI5K6nD3wseEzexfJJ2RtO3CbEhm9h+STrr7huwP51Xu/u8t0ttjmuTMzQ3qrdzM0t9Rge9dnjNe56GIPf8ySe+6+3vufk7SLyStLKCPlufu+yWdvOjplZK2Zve3auw/T9OV6a0luHvJ3Q9k909LujCzdKHvXaKvQhQR/nmS/jLu8ZBaa8pvl7TXzN4ws56im5nA7AszI2W3Vxfcz8UqztzcTBfNLN0y710tM17nrYjwTzSbSCsNOXzD3f9Z0l2S1mSHt6jOjyQt1tg0biVJG4tsJptZ+jlJ33P3U0X2Mt4EfRXyvhUR/iFJ88c9/rKkYwX0MSF3P5bdnpD0gsZOU1rJ8QuTpGa3Jwru52/c/bi7n3f3UUk/VoHvXTaz9HOSfubuz2dPF/7eTdRXUe9bEeF/XdI1ZvYVM5su6duSdhXQx+eY2YzsgxiZ2QxJ31TrzT68S9Kq7P4qSTsL7OXvtMrMzeVmllbB712rzXhdyEU+2VDGf0maJmmLu69vehMTMLNFGtvbS2PfePx5kb2Z2XZJt2rsW1/HJX1f0n9L+qWkBZL+LOlb7t70D97K9HarJjlzc4N6Kzez9IAKfO/ynPE6l364wg+IiSv8gKAIPxAU4QeCIvxAUIQfCIrwA0ERfiAowg8E9f/Ex0YKZYOZcwAAAABJRU5ErkJggg==\n",
      "text/plain": [
       "<Figure size 432x288 with 1 Axes>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "# 训练、测试集\n",
    "(train_x, train_y), (test_x, test_y) = tf.keras.datasets.mnist.load_data()\n",
    "\n",
    "# 数据形状和类型变换\n",
    "train_x = train_x.reshape([-1,28*28]).astype('float32')\n",
    "test_x = test_x.reshape([-1,28*28]).astype('float32')\n",
    "train_y = train_y.astype('int32')\n",
    "test_y = test_y.astype('int32')\n",
    "\n",
    "# 显示训练集\n",
    "%pylab inline\n",
    "import matplotlib.pyplot as plt\n",
    "import matplotlib.image as mpimg\n",
    "imgplot = plt.imshow(train_x[0].reshape([28,28]),cmap='gray')\n",
    "#print('数字形式的图片：%s' %train_x[0].reshape([28,28]))\n",
    "print('图片数字：%s' %train_y[0])\n",
    "print('训练数据量 %s, %s' %(train_x.shape,train_y.shape))\n",
    "print('测试数据量 %s, %s' %(test_x.shape,test_y.shape))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# 自搭模型"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "# 初始化方法\n",
    "def inits(shape):\n",
    "    return tf.random_uniform(shape,\n",
    "            minval=-np.sqrt(5) * np.sqrt(1.0 / shape[0]),\n",
    "            maxval=np.sqrt(5) * np.sqrt(1.0 / shape[0]))\n",
    "# 定义模型\n",
    "class Model(object):\n",
    "    def __init__(self):\n",
    "        # 参数初始化\n",
    "        self.W1 = tfe.Variable(inits([784,512]))\n",
    "        self.b1 = tfe.Variable(inits([512]))\n",
    "        self.W2 = tfe.Variable(inits([512,256]))\n",
    "        self.b2 = tfe.Variable(inits([256]))\n",
    "        self.W = tfe.Variable(inits([256,10]))\n",
    "        self.b = tfe.Variable(inits([10]))\n",
    "    def __call__(self, x):\n",
    "        # 正向传递\n",
    "        y = tf.nn.relu(tf.matmul(x, self.W1) + self.b1)\n",
    "        y = tf.nn.relu(tf.matmul(y, self.W2) + self.b2)\n",
    "        y = tf.matmul(y, self.W) + self.b\n",
    "        return y\n",
    "# 实例模型\n",
    "model = Model()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# 训练"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch: 000 | train loss: 0.922 | train acc: 0.891 | test loss: 1.114 | test acc: 0.879\n",
      "Epoch: 001 | train loss: 0.558 | train acc: 0.914 | test loss: 0.785 | test acc: 0.896\n",
      "Epoch: 002 | train loss: 0.387 | train acc: 0.929 | test loss: 0.645 | test acc: 0.908\n",
      "Epoch: 003 | train loss: 0.302 | train acc: 0.939 | test loss: 0.579 | test acc: 0.914\n",
      "Epoch: 004 | train loss: 0.256 | train acc: 0.946 | test loss: 0.542 | test acc: 0.917\n",
      "Epoch: 005 | train loss: 0.209 | train acc: 0.953 | test loss: 0.509 | test acc: 0.918\n",
      "Epoch: 006 | train loss: 0.177 | train acc: 0.959 | test loss: 0.483 | test acc: 0.923\n",
      "Epoch: 007 | train loss: 0.155 | train acc: 0.963 | test loss: 0.472 | test acc: 0.922\n",
      "Epoch: 008 | train loss: 0.134 | train acc: 0.969 | test loss: 0.455 | test acc: 0.924\n",
      "Epoch: 009 | train loss: 0.123 | train acc: 0.971 | test loss: 0.446 | test acc: 0.928\n",
      "Final Test Loss: 0.928\n"
     ]
    }
   ],
   "source": [
    "from sklearn.metrics import accuracy_score\n",
    "# 误差函数\n",
    "def loss(logits, label):\n",
    "    loss = tf.losses.sparse_softmax_cross_entropy(labels=label, logits=logits)\n",
    "    return loss\n",
    "\n",
    "# 更新方式\n",
    "def train(model, x, y, learning_rate, batch_size, epoch):\n",
    "    # 更新次数\n",
    "    for e in range(epoch):\n",
    "        # 打乱顺序\n",
    "        r = np.random.permutation(len(x))\n",
    "        x = [x[i] for i in r]\n",
    "        y = [y[i] for i in r]\n",
    "        # 批量更新\n",
    "        for b in range(0,len(x),batch_size):\n",
    "            # 计算梯度\n",
    "            with tf.GradientTape() as tape:\n",
    "                loss_value = loss(model(np.array(x[b:b+batch_size])), np.array(y[b:b+batch_size]))\n",
    "                dW1, db1, dW2, db2, dW, db = tape.gradient(loss_value, \n",
    "                                   [model.W1, model.b1, model.W2, model.b2, model.W, model.b])\n",
    "            # 训练更新\n",
    "            model.W1.assign_sub(dW1 * learning_rate)\n",
    "            model.b1.assign_sub(db1 * learning_rate)\n",
    "            model.W2.assign_sub(dW2 * learning_rate)\n",
    "            model.b2.assign_sub(db2 * learning_rate)\n",
    "            model.W.assign_sub(dW * learning_rate)\n",
    "            model.b.assign_sub(db * learning_rate)\n",
    "        # 显示\n",
    "        train_p = tf.concat([model(train_x[b:b+batch_size]) for b in range(0,len(train_x),batch_size)],axis=0)\n",
    "        # 显示\n",
    "        test_p = model(test_x)\n",
    "        print(\"Epoch: %03d | train loss: %.3f | train acc: %.3f | test loss: %.3f | test acc: %.3f\" \n",
    "              %(e, loss(train_p, train_y), accuracy_score(tf.argmax(train_p,1), train_y),\n",
    "                   loss(test_p, test_y), accuracy_score(tf.argmax(test_p,1), test_y)))\n",
    "\n",
    "# 训练\n",
    "train(model, train_x, train_y, learning_rate = 0.001, batch_size = 256, epoch = 10)\n",
    "\n",
    "# 评估\n",
    "test_p = model(test_x)\n",
    "print(\"Final Test Loss: %s\" %accuracy_score(tf.argmax(test_p,1), test_y))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# 用 API 建模型"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "model = tf.keras.Sequential([\n",
    "          tf.keras.layers.Dense(512, activation=tf.nn.relu, input_shape=(784,)),\n",
    "          tf.keras.layers.Dense(256, activation=tf.nn.relu),\n",
    "          tf.keras.layers.Dense(10)\n",
    "        ])"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# 用 API 训练"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch: 000 | train loss: 5.330 | train acc: 0.872 | test loss: 0.901 | test acc: 0.878\n",
      "Epoch: 001 | train loss: 0.653 | train acc: 0.929 | test loss: 0.653 | test acc: 0.899\n",
      "Epoch: 002 | train loss: 0.433 | train acc: 0.945 | test loss: 0.556 | test acc: 0.907\n",
      "Epoch: 003 | train loss: 0.325 | train acc: 0.956 | test loss: 0.504 | test acc: 0.912\n",
      "Epoch: 004 | train loss: 0.259 | train acc: 0.963 | test loss: 0.471 | test acc: 0.917\n",
      "Epoch: 005 | train loss: 0.214 | train acc: 0.968 | test loss: 0.449 | test acc: 0.920\n",
      "Epoch: 006 | train loss: 0.181 | train acc: 0.972 | test loss: 0.433 | test acc: 0.923\n",
      "Epoch: 007 | train loss: 0.156 | train acc: 0.976 | test loss: 0.419 | test acc: 0.926\n",
      "Epoch: 008 | train loss: 0.137 | train acc: 0.978 | test loss: 0.408 | test acc: 0.928\n",
      "Epoch: 009 | train loss: 0.121 | train acc: 0.981 | test loss: 0.400 | test acc: 0.928\n",
      "Final Test Loss: 0.928\n"
     ]
    }
   ],
   "source": [
    "from sklearn.metrics import accuracy_score\n",
    "# 误差函数\n",
    "def loss(logits, label):\n",
    "    loss = tf.losses.sparse_softmax_cross_entropy(labels=label, logits=logits)\n",
    "    return loss\n",
    "\n",
    "# 更新方式\n",
    "def train(model, x, y, learning_rate, batch_size, epoch):\n",
    "    # 更新次数\n",
    "    for e in range(epoch):\n",
    "        # 创建用于记录训练集误差和准确率的方法\n",
    "        epoch_loss_avg = tfe.metrics.Mean()\n",
    "        epoch_accuracy = tfe.metrics.Accuracy()\n",
    "        # 批量更新\n",
    "        for b in range(0,len(x),batch_size):\n",
    "            # 计算梯度\n",
    "            with tf.GradientTape() as tape:\n",
    "                loss_value = loss(model(np.array(x[b:b+batch_size])), np.array(y[b:b+batch_size]))\n",
    "                grads = tape.gradient(loss_value, model.variables)\n",
    "            # 训练更新\n",
    "            optimizer = tf.train.GradientDescentOptimizer(learning_rate=learning_rate)\n",
    "            optimizer.apply_gradients(zip(grads, model.variables),\n",
    "                            global_step=tf.train.get_or_create_global_step())\n",
    "            # 因为记忆量限制，记录训练集每个 batch 的误差和准确率\n",
    "            epoch_loss_avg(loss_value)\n",
    "            epoch_accuracy(tf.argmax(model(np.array(x[b:b+batch_size])), axis=1, output_type=tf.int32), np.array(y[b:b+batch_size]))\n",
    "        # 显示\n",
    "        test_p = model(test_x)\n",
    "        print(\"Epoch: %03d | train loss: %.3f | train acc: %.3f | test loss: %.3f | test acc: %.3f\" \n",
    "              %(e, epoch_loss_avg.result(), epoch_accuracy.result(),\n",
    "                   loss(test_p, test_y), accuracy_score(tf.argmax(test_p,1), test_y)))\n",
    "\n",
    "# 训练\n",
    "r = np.random.permutation(len(train_y))\n",
    "# 送入打乱顺序的 train_x 和 train_y\n",
    "train(model, [train_x[i] for i in r], [train_y[i] for i in r], learning_rate = 0.001, batch_size = 256, epoch = 10)\n",
    "\n",
    "# 评估\n",
    "test_p = model(test_x)\n",
    "print(\"Final Test Loss: %s\" %accuracy_score(tf.argmax(test_p,1), test_y))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# 预测"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "真实数字：2\n",
      "预测数字：2\n"
     ]
    },
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAP8AAAD8CAYAAAC4nHJkAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDIuMi4zLCBodHRwOi8vbWF0cGxvdGxpYi5vcmcvIxREBQAADYNJREFUeJzt3X+oXPWZx/HPZ20CYouaFLMXYzc16rIqauUqiy2LSzW6S0wMWE3wjyy77O0fFbYYfxGECEuwLNvu7l+BFC9NtLVpuDHGWjYtsmoWTPAqGk2TtkauaTbX3A0pNkGkJnn2j3uy3MY7ZyYzZ+bMzfN+QZiZ88w552HI555z5pw5X0eEAOTzJ3U3AKAehB9IivADSRF+ICnCDyRF+IGkCD+QFOEHkiL8QFKf6+XKbHM5IdBlEeFW3tfRlt/2nbZ/Zfs92491siwAveV2r+23fZ6kX0u6XdJBSa9LWhERvyyZhy0/0GW92PLfLOm9iHg/Iv4g6ceSlnawPAA91En4L5X02ymvDxbT/ojtIdujtkc7WBeAinXyhd90uxaf2a2PiPWS1kvs9gP9pJMt/0FJl015PV/Soc7aAdArnYT/dUlX2v6y7dmSlkvaVk1bALqt7d3+iDhh+wFJ2yWdJ2k4IvZU1hmArmr7VF9bK+OYH+i6nlzkA2DmIvxAUoQfSIrwA0kRfiApwg8kRfiBpAg/kBThB5Ii/EBShB9IivADSRF+IKme3rob7XnooYdK6+eff37D2nXXXVc67z333NNWT6etW7eutP7aa681rD399NMdrRudYcsPJEX4gaQIP5AU4QeSIvxAUoQfSIrwA0lx994+sGnTptJ6p+fi67R///6Gtdtuu6103gMHDlTdTgrcvRdAKcIPJEX4gaQIP5AU4QeSIvxAUoQfSKqj3/PbHpN0TNJJSSciYrCKps41dZ7H37dvX2l9+/btpfXLL7+8tH7XXXeV1hcuXNiwdv/995fO++STT5bW0Zkqbubx1xFxpILlAOghdvuBpDoNf0j6ue03bA9V0RCA3uh0t/+rEXHI9iWSfmF7X0S8OvUNxR8F/jAAfaajLX9EHCoeJyQ9J+nmad6zPiIG+TIQ6C9th9/2Bba/cPq5pEWS3q2qMQDd1clu/zxJz9k+vZwfRcR/VtIVgK5rO/wR8b6k6yvsZcYaHCw/olm2bFlHy9+zZ09pfcmSJQ1rR46Un4U9fvx4aX327Nml9Z07d5bWr7++8X+RuXPnls6L7uJUH5AU4QeSIvxAUoQfSIrwA0kRfiAphuiuwMDAQGm9uBaioWan8u64447S+vj4eGm9E6tWrSqtX3311W0v+8UXX2x7XnSOLT+QFOEHkiL8QFKEH0iK8ANJEX4gKcIPJMV5/gq88MILpfUrrriitH7s2LHS+tGjR8+6p6osX768tD5r1qwedYKqseUHkiL8QFKEH0iK8ANJEX4gKcIPJEX4gaQ4z98DH3zwQd0tNPTwww+X1q+66qqOlr9r1662aug+tvxAUoQfSIrwA0kRfiApwg8kRfiBpAg/kJQjovwN9rCkxZImIuLaYtocSZskLZA0JuneiPhd05XZ5StD5RYvXlxa37x5c2m92RDdExMTpfWy+wG88sorpfOiPRFRPlBEoZUt/w8k3XnGtMckvRQRV0p6qXgNYAZpGv6IeFXSmbeSWSppQ/F8g6S7K+4LQJe1e8w/LyLGJal4vKS6lgD0Qtev7bc9JGmo2+sBcHba3fIftj0gScVjw299ImJ9RAxGxGCb6wLQBe2Gf5uklcXzlZKer6YdAL3SNPy2n5X0mqQ/t33Q9j9I+o6k223/RtLtxWsAM0jTY/6IWNGg9PWKe0EXDA6WH201O4/fzKZNm0rrnMvvX1zhByRF+IGkCD+QFOEHkiL8QFKEH0iKW3efA7Zu3dqwtmjRoo6WvXHjxtL6448/3tHyUR+2/EBShB9IivADSRF+ICnCDyRF+IGkCD+QVNNbd1e6Mm7d3ZaBgYHS+ttvv92wNnfu3NJ5jxw5Ulq/5ZZbSuv79+8vraP3qrx1N4BzEOEHkiL8QFKEH0iK8ANJEX4gKcIPJMXv+WeAkZGR0nqzc/llnnnmmdI65/HPXWz5gaQIP5AU4QeSIvxAUoQfSIrwA0kRfiCppuf5bQ9LWixpIiKuLaY9IekfJf1v8bbVEfGzbjV5rluyZElp/cYbb2x72S+//HJpfc2aNW0vGzNbK1v+H0i6c5rp/xYRNxT/CD4wwzQNf0S8KuloD3oB0EOdHPM/YHu37WHbF1fWEYCeaDf86yQtlHSDpHFJ3230RttDtkdtj7a5LgBd0Fb4I+JwRJyMiFOSvi/p5pL3ro+IwYgYbLdJANVrK/y2p95Odpmkd6tpB0CvtHKq71lJt0r6ou2DktZIutX2DZJC0pikb3axRwBd0DT8EbFimslPdaGXc1az39uvXr26tD5r1qy21/3WW2+V1o8fP972sjGzcYUfkBThB5Ii/EBShB9IivADSRF+IClu3d0Dq1atKq3fdNNNHS1/69atDWv8ZBeNsOUHkiL8QFKEH0iK8ANJEX4gKcIPJEX4gaQcEb1bmd27lfWRTz75pLTeyU92JWn+/PkNa+Pj4x0tGzNPRLiV97HlB5Ii/EBShB9IivADSRF+ICnCDyRF+IGk+D3/OWDOnDkNa59++mkPO/msjz76qGGtWW/Nrn+48MIL2+pJki666KLS+oMPPtj2sltx8uTJhrVHH320dN6PP/64kh7Y8gNJEX4gKcIPJEX4gaQIP5AU4QeSIvxAUk3P89u+TNJGSX8q6ZSk9RHxH7bnSNokaYGkMUn3RsTvutcqGtm9e3fdLTS0efPmhrVm9xqYN29eaf2+++5rq6d+9+GHH5bW165dW8l6Wtnyn5C0KiL+QtJfSvqW7aslPSbppYi4UtJLxWsAM0TT8EfEeES8WTw/JmmvpEslLZW0oXjbBkl3d6tJANU7q2N+2wskfUXSLknzImJcmvwDIemSqpsD0D0tX9tv+/OSRiR9OyJ+b7d0mzDZHpI01F57ALqlpS2/7VmaDP4PI2JLMfmw7YGiPiBpYrp5I2J9RAxGxGAVDQOoRtPwe3IT/5SkvRHxvSmlbZJWFs9XSnq++vYAdEvTW3fb/pqkHZLe0eSpPklarcnj/p9I+pKkA5K+ERFHmywr5a27t2zZUlpfunRpjzrJ5cSJEw1rp06dalhrxbZt20rro6OjbS97x44dpfWdO3eW1lu9dXfTY/6I+G9JjRb29VZWAqD/cIUfkBThB5Ii/EBShB9IivADSRF+ICmG6O4DjzzySGm90yG8y1xzzTWl9W7+bHZ4eLi0PjY21tHyR0ZGGtb27dvX0bL7GUN0AyhF+IGkCD+QFOEHkiL8QFKEH0iK8ANJcZ4fOMdwnh9AKcIPJEX4gaQIP5AU4QeSIvxAUoQfSIrwA0kRfiApwg8kRfiBpAg/kBThB5Ii/EBShB9Iqmn4bV9m+79s77W9x/Y/FdOfsP0/tt8q/v1t99sFUJWmN/OwPSBpICLetP0FSW9IulvSvZKOR8S/trwybuYBdF2rN/P4XAsLGpc0Xjw/ZnuvpEs7aw9A3c7qmN/2AklfkbSrmPSA7d22h21f3GCeIdujtkc76hRApVq+h5/tz0t6RdLaiNhie56kI5JC0j9r8tDg75ssg91+oMta3e1vKfy2Z0n6qaTtEfG9aeoLJP00Iq5tshzCD3RZZTfwtG1JT0naOzX4xReBpy2T9O7ZNgmgPq182/81STskvSPpVDF5taQVkm7Q5G7/mKRvFl8Oli2LLT/QZZXu9leF8APdx337AZQi/EBShB9IivADSRF+ICnCDyRF+IGkCD+QFOEHkiL8QFKEH0iK8ANJEX4gKcIPJNX0Bp4VOyLpgymvv1hM60f92lu/9iXRW7uq7O3PWn1jT3/P/5mV26MRMVhbAyX6tbd+7Uuit3bV1Ru7/UBShB9Iqu7wr695/WX6tbd+7Uuit3bV0lutx/wA6lP3lh9ATWoJv+07bf/K9nu2H6ujh0Zsj9l+pxh5uNYhxoph0CZsvztl2hzbv7D9m+Jx2mHSauqtL0ZuLhlZutbPrt9GvO75br/t8yT9WtLtkg5Kel3Sioj4ZU8bacD2mKTBiKj9nLDtv5J0XNLG06Mh2f4XSUcj4jvFH86LI+LRPuntCZ3lyM1d6q3RyNJ/pxo/uypHvK5CHVv+myW9FxHvR8QfJP1Y0tIa+uh7EfGqpKNnTF4qaUPxfIMm//P0XIPe+kJEjEfEm8XzY5JOjyxd62dX0lct6gj/pZJ+O+X1QfXXkN8h6ee237A9VHcz05h3emSk4vGSmvs5U9ORm3vpjJGl++aza2fE66rVEf7pRhPpp1MOX42IGyX9jaRvFbu3aM06SQs1OYzbuKTv1tlMMbL0iKRvR8Tv6+xlqmn6quVzqyP8ByVdNuX1fEmHauhjWhFxqHickPScJg9T+snh04OkFo8TNffz/yLicEScjIhTkr6vGj+7YmTpEUk/jIgtxeTaP7vp+qrrc6sj/K9LutL2l23PlrRc0rYa+vgM2xcUX8TI9gWSFqn/Rh/eJmll8XylpOdr7OWP9MvIzY1GllbNn12/jXhdy0U+xamMf5d0nqThiFjb8yamYftyTW7tpclfPP6ozt5sPyvpVk3+6uuwpDWStkr6iaQvSTog6RsR0fMv3hr0dqvOcuTmLvXWaGTpXarxs6tyxOtK+uEKPyAnrvADkiL8QFKEH0iK8ANJEX4gKcIPJEX4gaQIP5DU/wG6SwYLYCwMKQAAAABJRU5ErkJggg==\n",
      "text/plain": [
       "<Figure size 432x288 with 1 Axes>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "imgplot = plt.imshow(test_x[1].reshape((28,28)),cmap='gray')\n",
    "print('真实数字：%s' %test_y[1])\n",
    "print('预测数字：%s' %tf.argmax(test_p[1]).numpy())"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.6"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 1
}
